// SPDX-FileCopyrightText: Copyright 2015 Graham Sellers, Richard Wright Jr. and Nicholas Haemel
// SPDX-License-Identifier: MIT

// Code obtained from OpenGL SuperBible, Seventh Edition by Graham Sellers, Richard Wright Jr. and
// Nicholas Haemel. Modified to suit needs.

#version 460 core

// Backend selection: the macros below let the parameter declarations further down be
// written once and expand differently depending on whether VULKAN is defined.
#ifdef VULKAN

// NOTE(review): HAS_EXTENDED_TYPES and the BINDING_* macros are defined here but are not
// referenced anywhere else in this file.
#define HAS_EXTENDED_TYPES 1
// Parameters are wrapped in a push-constant block named PushConstants.
#define BEGIN_PUSH_CONSTANTS layout(push_constant) uniform PushConstants {
#define END_PUSH_CONSTANTS };
// Expands to nothing: members of the push-constant block carry no per-member qualifier.
#define UNIFORM(n)
#define BINDING_INPUT_BUFFER 0
#define BINDING_OUTPUT_IMAGE 1

#else // ^^^ Vulkan ^^^ // vvv OpenGL vvv

// Request GL_NV_gpu_shader5; HAS_EXTENDED_TYPES records whether the extension macro is
// defined afterwards.
#extension GL_NV_gpu_shader5 : enable
#ifdef GL_NV_gpu_shader5
#define HAS_EXTENDED_TYPES 1
#else
#define HAS_EXTENDED_TYPES 0
#endif
// No wrapping block here: each parameter becomes a standalone uniform.
#define BEGIN_PUSH_CONSTANTS
#define END_PUSH_CONSTANTS
// Declares a uniform at explicit location n.
#define UNIFORM(n) layout(location = n) uniform
#define BINDING_INPUT_BUFFER 0
#define BINDING_OUTPUT_IMAGE 0

#endif

// Dispatch parameters. With VULKAN defined these are members of a push-constant block;
// otherwise they are plain uniforms at explicit locations 0-3.
BEGIN_PUSH_CONSTANTS
// Elements whose global index (buffer_offset + index within the work group) is below this
// value get the carried-in accumulated_data added to their prefix sum in main().
UNIFORM(0) uint min_accumulation_base;
// main() subtracts the scanned value at shared index (max_accumulation_base - 1) from the
// final value before storing it to accumulated_data.
// NOTE(review): a value of 0 would underflow that index - presumably never passed; confirm
// against the host code.
UNIFORM(1) uint max_accumulation_base;
// Shared-array index of the scanned value that main() uses as the final value when
// updating accumulated_data.
UNIFORM(2) uint accumulation_limit;
// Element offset added to every index into input_data and output_data.
UNIFORM(3) uint buffer_offset;
END_PUSH_CONSTANTS

// Number of consecutive elements each invocation loads from input_data and stores to
// output_data.
#define LOCAL_RESULTS 4
// Number of elements processed by one work group (despite the name, this is the total for
// the whole work group, not for a single invocation).
#define QUERIES_PER_INVOC 2048

// One invocation per LOCAL_RESULTS elements.
layout(local_size_x = QUERIES_PER_INVOC / LOCAL_RESULTS) in;

// 64-bit values stored as uvec2: .x is the low word, .y the high word (see AddUint64).
//
// The input and output arrays are runtime-sized because main() indexes them at
// buffer_offset + local index. With a fixed work-group-sized bound, any non-zero
// buffer_offset would index past the declared array length, which is undefined behaviour.
// The std430 element layout (tightly packed uvec2) is the same either way.
layout(std430, binding = 0) readonly buffer block1 {
    uvec2 input_data[];
};

layout(std430, binding = 1) writeonly coherent buffer block2 {
    uvec2 output_data[];
};

// Running total read at the start of main() and rewritten by invocation 0 at the end.
layout(std430, binding = 2) coherent buffer block3 {
    uvec2 accumulated_data;
};

// Scratch space for the work-group scan: one slot per element handled by the work group.
shared uvec2 shared_data[gl_WorkGroupSize.x * LOCAL_RESULTS];

// Adds two unsigned 64-bit integers that are each packed into a uvec2
// (.x = low 32 bits, .y = high 32 bits) and returns the sum in the same packing.
// Overflow out of the high word wraps around.
uvec2 AddUint64(uvec2 value_1, uvec2 value_2) {
    // uaddCarry writes 1 to its out-parameter when the low-word addition overflows.
    uint low_carry;
    const uint low_word = uaddCarry(value_1.x, value_2.x, low_carry);
    const uint high_word = value_1.y + value_2.y + low_carry;
    return uvec2(low_word, high_word);
}

// Inclusive prefix sum over one work group's worth of 64-bit counters.
//
// Every counter is a uvec2 (.x = low word, .y = high word). Each invocation loads
// LOCAL_RESULTS consecutive counters starting at buffer_offset, the work group scans them in
// shared memory, the carried-in value of accumulated_data is added to every counter whose
// global index is below min_accumulation_base, and the results are written to output_data.
// Invocation 0 finally stores the new running total into accumulated_data.
void main(void) {
    const uint id = gl_LocalInvocationID.x;
    const uint work_size = gl_WorkGroupSize.x;
    // Total number of elements scanned by the work group; expected to be a power of two.
    const uint element_count = work_size * LOCAL_RESULTS;
    const uint local_base = id * LOCAL_RESULTS;
    const uint global_base = buffer_offset + local_base;

    // Read the carried-in total once, before invocation 0 overwrites it at the end.
    const uvec2 accum = accumulated_data;
    uvec2 base_value[LOCAL_RESULTS];
    for (uint i = 0; i < LOCAL_RESULTS; i++) {
        base_value[i] = (global_base + i) < min_accumulation_base ? accum : uvec2(0);
    }

    // Each invocation is responsible for LOCAL_RESULTS consecutive elements of the
    // shared array; stage them from the input buffer.
    for (uint i = 0; i < LOCAL_RESULTS; i++) {
        shared_data[local_base + i] = input_data[global_base + i];
    }
    // Synchronize to make sure that everyone has initialized
    // their elements of shared_data[] with data loaded from
    // the input arrays
    barrier();
    memoryBarrierShared();

    // The number of steps is the log base 2 of the element count. findMSB gives the exact
    // integer result for a power of two, whereas truncating a floating-point log2() could
    // lose a step if the result came out marginally below the exact value.
    const uint steps = uint(findMSB(element_count));
    // In every step exactly half of the elements (the upper half of each block) receive an
    // addition, so element_count / 2 pair operations are needed. There are fewer invocations
    // than pairs when LOCAL_RESULTS > 2, so each invocation strides over several pairs.
    const uint pair_count = element_count / 2u;
    for (uint level = 0; level < steps; level++) {
        const uint mask = (1u << level) - 1u;
        for (uint pair = id; pair < pair_count; pair += work_size) {
            // Blocks are 2^(level + 1) elements wide. rd_id is the last element of the
            // block's lower half, wr_id is this pair's element in the upper half. Reads
            // only touch lower halves and writes only upper halves, so pairs within one
            // step never conflict.
            const uint rd_id = ((pair >> level) << (level + 1u)) + mask;
            const uint wr_id = rd_id + 1u + (pair & mask);
            // Accumulate the read data into our element
            shared_data[wr_id] = AddUint64(shared_data[rd_id], shared_data[wr_id]);
        }
        // Synchronize again to make sure that everyone
        // has caught up with us
        barrier();
        memoryBarrierShared();
    }

    // Add the carried-in accumulation to the elements that requested it
    for (uint i = 0; i < LOCAL_RESULTS; i++) {
        shared_data[local_base + i] = AddUint64(shared_data[local_base + i], base_value[i]);
    }
    barrier();
    memoryBarrierShared();

    // Finally write our data back to the output buffer
    for (uint i = 0; i < LOCAL_RESULTS; i++) {
        output_data[global_base + i] = shared_data[local_base + i];
    }

    // A single invocation updates the running total. All barriers are already behind us, so
    // the early return below cannot leave other invocations waiting.
    if (id == 0) {
        if (min_accumulation_base >= accumulation_limit + 1) {
            accumulated_data = shared_data[accumulation_limit];
            return;
        }
        uvec2 reset_value = shared_data[max_accumulation_base - 1];
        uvec2 final_value = shared_data[accumulation_limit];
        // Two's complement negation (-x == ~x + 1), so the sum below is
        // final_value - reset_value.
        reset_value = AddUint64(uvec2(1, 0), ~reset_value);
        accumulated_data = AddUint64(final_value, reset_value);
    }
}